import re
import subprocess
import sys

from hyperopt import hp, fmin, tpe, Trials

# Tunable parameter table.
# Each entry is [name, candidate integers, scale]. The value written into the
# code file is (chosen integer) * scale; a float scale makes the value be
# printed with one decimal place. Candidate ranges include the listed end - 1.
params = [
    # name                       candidates         scale
    ['PLAN_STEPS',               range(10, 41),     1],
    ['GOODS_LIFE_FACTOR',        range(0, 401),     1],
    ['ROUND_APPEND_FACTOR',      range(0, 11),      0.1],
    ['FREE_AWAY_DISTANCE',       range(0, 11),      1],
    ['ROBOT_MOVE_MARGIN_TIME',   range(0, 21),      1],
]

def change_file_lines(filepath: str, map: int, content: str):
    """Replace the parameter section of one map inside *filepath*.

    The section is delimited by a line containing ``// autoset<map>`` and the
    next following line containing ``// end autoset<map>``. Both marker lines
    are kept; every line between them is replaced by *content*, which should
    therefore end with a newline.

    Args:
        filepath: Path of the source file that is rewritten in place.
        map: Map number whose section is replaced.
        content: Replacement text inserted between the two marker lines.

    Raises:
        SystemExit: With status 1 (after printing an error) when the start
            marker, or an end marker after it, cannot be found.
    """
    # `(?!\d)` keeps map 1 from also matching the markers of map 10, 11, ...
    tag = re.escape(str(map))
    start_pat = re.compile(rf'// autoset{tag}(?!\d)')
    end_pat = re.compile(rf'// end autoset{tag}(?!\d)')

    with open(filepath, 'r') as file:
        lines = file.readlines()

    start = end = None
    for i, line in enumerate(lines):
        if start is None:
            if start_pat.search(line):
                start = i
        elif end_pat.search(line):
            end = i
            break

    if end is None:
        print("未找到相应代码位置！", file=sys.stderr)
        sys.exit(1)

    # Keep everything up to and including the start marker, then the new
    # content, then everything from the end marker onwards.
    new_lines = lines[:start + 1] + [content] + lines[end:]
    with open(filepath, 'w') as file:
        file.writelines(new_lines)

def parse_judge_score(stdout: bytes) -> int:
    """Return the integer score printed by the local judge.

    Assumes the output ends with ``...:<score>}`` plus optional trailing
    whitespace, which is the shape the tuner has always parsed. Confirm this
    against the judge if its output format changes.
    """
    text = stdout.decode(errors='replace').strip()
    return int(text.rstrip('}').rsplit(':', 1)[-1].strip())


def target(x, codefile, map, scripts: 'str | None', sand=123):
    """Objective function minimized by hyperopt.

    Writes the candidate parameters into *codefile*, evaluates them, and
    returns the negated score, because hyperopt minimizes while a higher
    score is better.

    Args:
        x: Candidate integers, one per entry of ``params`` and in the same
            order. Each is multiplied by that entry's scale before writing.
        codefile: Source file containing the ``// autoset<map>`` section.
        map: Map number being tuned.
        scripts: Submission version tag. When truthy, ``./zip.sh`` packages
            the code and the user types in the score. When empty or None,
            the code is built with ``./build.sh`` and scored by the local
            judge.
        sand: Random seed forwarded to the local judge via ``-s``.

    Returns:
        The negated score.

    Raises:
        subprocess.CalledProcessError: If ``./build.sh`` exits non-zero, so a
            stale binary is never scored against the new parameters.
    """
    param_text = ''
    for i, (name, _, scale) in enumerate(params):
        value = x[i] * scale
        # Float scales (e.g. 0.1) are printed with one decimal to avoid
        # artefacts such as 0.30000000000000004 in the generated code.
        shown = f'{value:.1f}' if isinstance(scale, float) else f'{value}'
        param_text += ' ' * 12 + f'{name} = ' + shown + ';\n'
    change_file_lines(codefile, map, param_text)

    if scripts:
        subprocess.run(['./zip.sh', scripts])
        prompt = f'参数表已更新，提交版本已打包为[{scripts}.zip]，请上传并输入得分\n'
        while True:
            try:
                ret = int(input(prompt))
                break
            except ValueError:
                # A typo must not abort the run and discard all trials so far.
                print('请输入整数得分')
    else:
        subprocess.run('./build.sh', stdout=subprocess.PIPE, check=True)
        # NOTE(review): './main' is passed with literal quote characters;
        # presumably the judge hands it to a shell. Confirm before changing.
        result = subprocess.run(['../sdk/PreliminaryJudge',
                                 '-m', f'maps/map{map}.txt',
                                 '-r', f'map{map}.rep',
                                 '-l', 'NONE',
                                 '-s', f'{sand}',
                                 f'"./main"'], stdout=subprocess.PIPE)
        ret = parse_judge_score(result.stdout)
    return -ret

if __name__ == '__main__':
    codefile = 'Hyperparameters.hpp'
    map = int(input('地图号：'))
    scripts = input('请输入自动打包版本号开头，或直接回车，进入本地调试模式\n')
    # The seed is only forwarded to the local judge; in packaging mode the
    # value 123 is just a placeholder.
    if scripts == '':
        rand = int(input('请指定随机种子\n'))
    else:
        rand = 123
    model = Trials()
    rounds = 1
    while True:
        # `model` accumulates every trial across calls. fmin counts the trials
        # already stored towards max_evals, so raising max_evals by one per
        # loop iteration runs one new evaluation each time.
        best_params = fmin(
            fn=lambda x: target(x, codefile, map, scripts+f'.{rounds}' if scripts else None, rand),
            space=[hp.choice(param[0], param[1]) for param in params],
            algo=tpe.suggest,
            max_evals=rounds,
            trials=model,
            show_progressbar=False
        )
        input_text = input(f"已尝试[{rounds}]轮，当前得分[{-model.losses()[-1]:.0f}]，最高得分[{-model.best_trial['result']['loss']:6.0f}]，继续吗？(按回车继续)\n")
        if input_text != '':
            break
        rounds += 1

    # fmin reports hp.choice results as indices into each candidate range.
    # Map them back to values and apply the table's scale before writing.
    param_text = ''
    for name, candidates, scale in params:
        value = candidates[best_params[name]] * scale
        shown = f'{value:.1f}' if isinstance(scale, float) else f'{value}'
        param_text += ' ' * 12 + f'{name} = ' + shown + ';\n'
    change_file_lines(codefile, map, param_text)
    print(f"一共试了[{rounds}]轮，最高得分[{-model.best_trial['result']['loss']:6.0f}]，最优参数已更新到代码文件")